from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from .Mix import MixedDataset
from .Expression import ExpressionDataset


def prepare_mix(arg):
    """Build train/test DataLoaders over ``MixedDataset``.

    Both splits are resized to 224x224 (the inline comment indicates a
    ResNet-18 input size) and normalized with the ImageNet channel
    statistics. The train split additionally gets random horizontal flips,
    mild brightness/contrast jitter and random erasing.

    Args:
        arg: Options object. ``batch_size`` and ``test_batch_size`` are read
            here, and ``arg`` is forwarded to ``MixedDataset``. If it has a
            ``num_workers`` attribute, that sets the loader worker count;
            otherwise 4 workers are used (the previous fixed value).

    Returns:
        tuple: ``(train_loader, test_loader)``. The train loader shuffles
        each epoch; the test loader keeps dataset order.
    """
    # ImageNet per-channel statistics, shared by both pipelines.
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    num_workers = getattr(arg, 'num_workers', 4)

    transform_train = transforms.Compose([
        # Resnet18 224*224 input
        transforms.Resize((224, 224)),
        transforms.ToTensor(),

        # Flip and jitter run on the tensor (after ToTensor); tensor inputs
        # for these transforms need a torchvision release that supports them.
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),

        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),

        # RandomErasing only accepts tensors. Placed after Normalize, its
        # default fill value (0) corresponds to the per-channel mean.
        transforms.RandomErasing(scale=(0.02, 0.1)),
    ])

    transform_test = transforms.Compose([
        # Resnet18 224*224 input
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    train_set = MixedDataset(phase='train', arg=arg, transform=transform_train)
    test_set = MixedDataset(phase='test', arg=arg, transform=transform_test)

    train_loader = DataLoader(train_set, batch_size=arg.batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=arg.test_batch_size,
                             shuffle=False, num_workers=num_workers)

    return train_loader, test_loader


def prepare_train(arg):
    """Create the loaders used for training.

    Training currently always uses the mixed dataset, so this just delegates.

    Args:
        arg: Options object, handed unchanged to ``prepare_mix``.

    Returns:
        Whatever ``prepare_mix`` returns (its ``(train_loader, test_loader)``).
    """
    loaders = prepare_mix(arg)
    return loaders


def prepare_test(arg):
    """Build an evaluation DataLoader over ``ExpressionDataset``.

    Images are resized to 224x224 and normalized with the ImageNet channel
    statistics; no augmentation is applied.

    Args:
        arg: Options object. Reads ``test_set`` (passed to
            ``ExpressionDataset``) and ``test_batch_size``. If it has a
            ``num_workers`` attribute, that sets the loader worker count;
            otherwise 4 workers are used (the previous fixed value).

    Returns:
        DataLoader: Non-shuffled loader over the selected test set.
    """
    num_workers = getattr(arg, 'num_workers', 4)

    transform_test_raf = transforms.Compose([
        # Resnet18 224*224 input
        transforms.Resize((224, 224)),
        transforms.ToTensor(),

        # Standard ImageNet per-channel mean/std -- the same normalization
        # prepare_mix applies, so evaluation inputs match training inputs.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    test_set = ExpressionDataset(test_set=arg.test_set, transform=transform_test_raf)
    test_loader = DataLoader(test_set, batch_size=arg.test_batch_size,
                             shuffle=False, num_workers=num_workers)

    return test_loader
